import os
import logging
import pandas as pd


# ==============================================================================================================
# HyperParameters
# ==============================================================================================================
# Training hyperparameters consumed by the model-training code.
BATCH_SIZE = 64  # samples per gradient step
EPOCHS = 800  # total training epochs; also baked into the NN checkpoint filename below
LR = 5e-5  # optimizer learning rate

# ==============================================================================================================
# Logger
# ==============================================================================================================
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger()


# ==============================================================================================================
# Generated filedir & filepath
# ==============================================================================================================
# Project root is the directory containing this file; all generated
# artifacts live in subdirectories next to it.
root = os.path.dirname(__file__)
origin_file_dir = os.path.join(root, "origin")  # raw input data (read-only)
data_file_dir = os.path.join(root, "data")  # processed arrays, Excel results, checkpoints
image_file_dir = os.path.join(root, "image")  # generated plots

# file directories — one subdirectory per plot kind
boxplot_image_file_dir = os.path.join(image_file_dir, "boxplot")
histogram_image_file_dir = os.path.join(image_file_dir, "histogram")
correlation_image_file_dir = os.path.join(image_file_dir, "correlation")
regression_image_file_dir = os.path.join(image_file_dir, "regression")
# data file save paths
# "washed" variants presumably hold the data after the dropped features /
# outlier cleaning step — TODO confirm against the producing script.
extract_data_file_path = os.path.join(data_file_dir, "extract_data.npy")
feature_data_file_path = os.path.join(data_file_dir, "feature_data.npy")
washed_extract_data_file_path = os.path.join(data_file_dir, "washed_extract_data.npy")
washed_feature_data_file_path = os.path.join(data_file_dir, "washed_feature_data.npy")
excel_result_file_path = os.path.join(data_file_dir, "results.xlsx")
excel_washed_result_file_path = os.path.join(data_file_dir, "washed_results.xlsx")
# model file save paths — epoch count is embedded so checkpoints from
# different training lengths do not overwrite each other.
nn_model_checkpoint_file_path = os.path.join(
    data_file_dir, f"nn_checkpoint_E{EPOCHS}.pth"
)
# svm_model_checkpoint_file_path = os.path.join(data_file_dir, f'svm_checkpoint.pth')
# elm_model_checkpoint_file_path = os.path.join(data_file_dir, f'elm_checkpoint.pth')
# Initialization: ensure every output directory exists before anything
# tries to write into it. os.makedirs(..., exist_ok=True) replaces the
# previous exists()/mkdir() pairs — it is race-safe (no TOCTOU window
# between the check and the creation) and also creates missing parent
# directories. Note: origin_file_dir is deliberately NOT created here;
# it must already contain the raw input data.
for _output_dir in (
    data_file_dir,
    image_file_dir,
    boxplot_image_file_dir,
    histogram_image_file_dir,
    correlation_image_file_dir,
    regression_image_file_dir,
):
    os.makedirs(_output_dir, exist_ok=True)

# ==============================================================================================================
# Data filedir & filepath
# ==============================================================================================================
# Yearly subdirectories of the raw-data folder that hold the source files.
data_file_dir_list = [
    os.path.join(origin_file_dir, "2020年"),
    os.path.join(origin_file_dir, "2021年"),
]
# Initialization: flatten every file found under each yearly directory
# into one list of absolute-ish paths (directory order as returned by
# os.listdir).
data_file_path_list = [
    os.path.join(year_dir, file_name)
    for year_dir in data_file_dir_list
    for file_name in os.listdir(year_dir)
]

# ==============================================================================================================
# Features
# ==============================================================================================================
# Per-feature extraction metadata, keyed by the feature's (Chinese) name
# as it appears in the source spreadsheets. Each entry has:
#   type     - "in_feature" (model input) or "out_feature" (model target)
#   start    - presumably the first row/column index to read in the source
#              sheet; None for single-value features — TODO confirm against
#              the extraction code
#   interval - presumably the step between consecutive readings; None for
#              single-value features — TODO confirm
#   total    - number of readings extracted for this feature
# NOTE(review): the keys are runtime identifiers matched against the raw
# data files — do not rename or translate them.
feature_dict = {
    "原矿": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "FeO": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "精矿": {"type": "in_feature", "start": 20, "interval": 4, "total": 6},
    "综尾": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "二段混磁精": {"type": "out_feature", "start": 20, "interval": 4, "total": 6},
    "强尾": {"type": "in_feature", "start": 20, "interval": 4, "total": 6},
    "粒度1#": {"type": "in_feature", "start": 22, "interval": 4, "total": 6},
    "粒度2#": {"type": "in_feature", "start": 22, "interval": 4, "total": 6},
    "正浮精": {"type": "out_feature", "start": 20, "interval": 2, "total": 12},
    "正浮尾": {"type": "out_feature", "start": 20, "interval": 4, "total": 6},
    "反浮精": {"type": "out_feature", "start": None, "interval": None, "total": 1},
    "反浮尾": {"type": "out_feature", "start": None, "interval": None, "total": 1},
    "反一扫精": {"type": "out_feature", "start": 20, "interval": 2, "total": 12},
    "浓度1#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "浓度2#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "流量1#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "流量2#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "KS-2粗1#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "KS-2粗2#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "KS-2精选": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "NaOH1#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "NaOH2#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "淀粉1#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "淀粉2#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "CaO1#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "CaO2#": {"type": "in_feature", "start": 20, "interval": 2, "total": 12},
    "激磁电流": {"type": "in_feature", "start": None, "interval": None, "total": 1},
}
# Features excluded from the "washed" (cleaned) feature set.
dropped_feature = ["粒度1#", "粒度2#", "激磁电流"]
# Same metadata as feature_dict minus the dropped features; insertion
# order of the remaining entries is preserved. The per-feature spec dicts
# are shared with feature_dict (shallow), exactly as before.
washed_feature_dict = {
    feat_name: feat_spec
    for feat_name, feat_spec in feature_dict.items()
    if feat_name not in dropped_feature
}

# Initialization
# Initialization: derived name/index lists used downstream. Positions in
# in_feature_index / out_feature_index refer to the insertion order of
# washed_feature_dict.
feature_name_list = list(feature_dict)
washed_feature_name_list = [
    name for name in feature_dict if name not in dropped_feature
]
in_feature_index = [
    idx
    for idx, spec in enumerate(washed_feature_dict.values())
    if spec["type"] == "in_feature"
]
out_feature_index = [
    idx
    for idx, spec in enumerate(washed_feature_dict.values())
    if spec["type"] == "out_feature"
]
in_feature_name_list = [
    name
    for name, spec in washed_feature_dict.items()
    if spec["type"] == "in_feature"
]
out_feature_name_list = [
    name
    for name, spec in washed_feature_dict.items()
    if spec["type"] == "out_feature"
]
